﻿/*
 创建日期: 2015.8.18
    创建者:张存
    邮箱:zhangcunliang@126.com
    说明:
        代码生成器基类：读取数据库元数据（表/视图、列、存储过程及其参数），并生成实体类和存储过程参数类的代码
    修改记录: 
        2015.11.6   增加生成model代码,加 partial关键字
        2015.11.9   sqlite
        2018.2.27   sqlserver 数据库支持备注的显示
        2019.5.14   重新整理命名规范
        2020.11.13  增加对属性set时的监控，增加列索引（BulkCopy 需要保持列顺序）
 */
using System;
using System.Collections.Generic;
using System.Data;
using System.Data.Common;
using ZhCun.Utils;

namespace ZhCun.CodeBuilder.Builders
{
    /// <summary>
    /// Base class for the database-specific code builders. It reads schema metadata
    /// (tables/views, columns, stored procedures and their parameters) through ADO.NET
    /// and renders C# entity / procedure-parameter classes from it. Subclasses supply the
    /// provider-specific connection and command objects and may override the metadata lookups.
    /// </summary>
    public abstract class BaseBuilder
    {
        /// <summary>Schema-table columns that flag auto-increment columns (SQL Server: IsIdentity, SQLite: IsAutoIncrement).</summary>
        private static readonly string[] IdentityColumns = { "IsIdentity", "IsAutoIncrement" };

        /// <summary>Line separators recognised when splitting column remarks.</summary>
        private static readonly string[] NewLines = { "\r\n", "\n", "\r" };

        /// <summary>
        /// Creates a builder for the given connection string.
        /// </summary>
        /// <param name="connStr">Connection string applied to the connections this builder opens.</param>
        public BaseBuilder(string connStr)
        {
            ConnectString = connStr;
        }

        /// <summary>Connection string supplied to the constructor.</summary>
        protected string ConnectString;

        /// <summary>
        /// Creates a provider connection configured with <see cref="ConnectString"/> without opening it.
        /// The caller owns the returned connection and is responsible for disposing it.
        /// </summary>
        public DbConnection CreateConnectionTest()
        {
            DbConnection conn = CreateConnection();
            conn.ConnectionString = ConnectString;
            return conn;
        }

        /// <summary>Creates a provider-specific connection object.</summary>
        protected internal abstract DbConnection CreateConnection();

        /// <summary>Creates a provider-specific command object.</summary>
        protected abstract DbCommand CreateDbCommand();

        /// <summary>
        /// Maps a database type name (e.g. "nvarchar", "int") to the C# type name used in generated code.
        /// </summary>
        /// <param name="dbDataType">Database type name; matched case-insensitively. Null is treated as "".</param>
        /// <returns>The C# type keyword/name; unknown database types map to "object".</returns>
        protected virtual string GetCSharpDataTypeString(string dbDataType)
        {
            // Invariant lower-casing: ToLower() is culture-sensitive ("INT" becomes "ınt" under tr-TR).
            switch ((dbDataType ?? string.Empty).ToLowerInvariant())
            {
                case "tinyint":
                case "int":
                case "integer":
                    return "int";
                case "smallint":
                    // smallint is a signed 16-bit integer (Int16); an unsigned type cannot hold its negative values.
                    return "short";
                case "bigint":
                    return "long";
                case "boolean":
                case "bit":
                    return "bool";
                case "uniqueidentifier":
                    return "Guid";
                case "smalldatetime":
                case "datetime":
                case "datetime2":
                case "date":
                    return "DateTime";
                case "decimal":
                case "money":
                case "numeric":
                case "smallmoney":
                    return "decimal";
                case "float":
                case "real":
                    return "double";
                case "varchar":
                case "nvarchar":
                case "char":
                case "nchar":
                case "text":
                case "ntext":
                case "":
                    return "string";
                case "binary":
                case "varbinary":
                case "image":
                    return "byte[]";
                default:
                    // Unknown database types are generated as object.
                    return "object";
            }
        }

        /// <summary>
        /// Returns column remarks (descriptions) for a table, keyed by column name.
        /// The base implementation returns an empty dictionary; it is meant to be overridden
        /// by providers that can read column comments.
        /// </summary>
        protected virtual Dictionary<string, string> GetColumnRemarks(string tbName)
        {
            return new Dictionary<string, string>();
        }

        /// <summary>
        /// Executes <paramref name="sql"/> on a new connection and returns the open reader.
        /// The reader is created with <see cref="CommandBehavior.CloseConnection"/>, so disposing it
        /// also closes the connection.
        /// </summary>
        protected virtual DbDataReader GetDataReader(string sql)
        {
            DbConnection conn = CreateConnection();
            // The other helpers apply ConnectString explicitly; do the same here when CreateConnection()
            // returned an unconfigured connection, without overriding one it configured itself.
            if (string.IsNullOrEmpty(conn.ConnectionString))
            {
                conn.ConnectionString = ConnectString;
            }
            try
            {
                using (DbCommand cmd = CreateDbCommand())
                {
                    cmd.Connection = conn;
                    cmd.CommandText = sql;
                    conn.Open();
                    return cmd.ExecuteReader(CommandBehavior.CloseConnection);
                }
            }
            catch
            {
                // No reader owns the connection yet, so release it here instead of leaking it.
                conn.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Formats a table name before it is used by <see cref="GetColumnInfo"/> (e.g. to quote it).
        /// The base implementation returns the name unchanged.
        /// </summary>
        protected virtual string FormatTableName(string tableName)
        {
            return tableName;
        }

        /// <summary>
        /// Creates a connection, applies <see cref="ConnectString"/> and opens it.
        /// The connection is disposed if configuring or opening it fails.
        /// </summary>
        private DbConnection OpenSchemaConnection()
        {
            DbConnection conn = CreateConnection();
            try
            {
                conn.ConnectionString = ConnectString;
                conn.Open();
                return conn;
            }
            catch
            {
                conn.Dispose();
                throw;
            }
        }

        /// <summary>
        /// Reads the schema collection <paramref name="collectionName"/> via <see cref="DbConnection.GetSchema(string)"/>
        /// and returns a copy sorted by <paramref name="sort"/> (a DataView sort expression).
        /// </summary>
        private DataTable GetSortedSchema(string collectionName, string sort)
        {
            using (DbConnection conn = OpenSchemaConnection())
            {
                DataTable schema = conn.GetSchema(collectionName);
                schema.DefaultView.Sort = sort;
                return schema.DefaultView.ToTable();
            }
        }

        /// <summary>
        /// Returns the tables and views of the current connection, sorted by type and name,
        /// using the SQL Server style "Tables" schema collection (Table_Name / Table_Type columns).
        /// </summary>
        protected internal virtual List<ModelTable> GetTableInfo()
        {
            DataTable tableInfo = GetSortedSchema("Tables", "Table_Type,Table_Name");
            List<ModelTable> list = new List<ModelTable>(tableInfo.Rows.Count);
            foreach (DataRow row in tableInfo.Rows)
            {
                ModelTable model = new ModelTable();
                model.IsUse = true;
                model.TableName = row["Table_Name"].ToString();
                model.TableRemark = model.TableName;
                // Ordinal comparison: a culture-aware ignore-case match fails for "VIEW" under e.g. tr-TR.
                bool isView = "VIEW".Equals(row["Table_Type"].ToString(), StringComparison.OrdinalIgnoreCase);
                model.TableType = (isView ? TableTypeEnum.View : TableTypeEnum.Table).ToString();
                list.Add(model);
            }
            return list;
        }

        /// <summary>
        /// Returns the primary-key column names of a table. The base implementation returns an
        /// empty list; providers whose schema table lacks key information override it.
        /// </summary>
        protected virtual List<string> GetPrimarykeys(string tableName)
        {
            return new List<string>();
        }

        /// <summary>
        /// Returns the column metadata of a table, read from the schema table of a
        /// "select * from {table}" query. If no primary key can be determined, one is guessed
        /// (see <see cref="MarkGuessedPrimaryKey"/>).
        /// </summary>
        /// <param name="tableName">Table name; passed through <see cref="FormatTableName"/> first.
        /// NOTE(review): the formatted name is also what GetColumnRemarks/GetPrimarykeys receive and what
        /// is stored in ModelField.TableName — confirm overrides expect that.</param>
        /// <returns>The visible columns, or null when the reader exposes no column schema.</returns>
        protected internal virtual List<ModelField> GetColumnInfo(string tableName)
        {
            tableName = FormatTableName(tableName);
            string sql = string.Format("select * from {0}", tableName);
            DataTable colInfo;
            // Only the schema is needed; the reader (and its connection) is released before the
            // remark / primary-key lookups below, which may open connections of their own.
            using (DbDataReader reader = GetDataReader(sql))
            {
                colInfo = reader.GetSchemaTable();
            }
            if (colInfo == null || colInfo.Rows.Count == 0) return null;

            Dictionary<string, string> remarks = GetColumnRemarks(tableName);
            List<ModelField> colList = new List<ModelField>();
            bool hasPK = false;
            for (int i = 0; i < colInfo.Rows.Count; i++)
            {
                DataRow row = colInfo.Rows[i];
                if (ReadBool(row, "IsHidden", false)) continue;

                ModelField model = new ModelField();
                // Index within the schema table (hidden columns still consume an index).
                model.ColumnIndex = i;
                model.ColumnName = row["ColumnName"].ToString();
                // Database-side type name, when the provider reports it.
                if (colInfo.Columns.Contains("DataTypeName"))
                {
                    model.ColumnDataType = row["DataTypeName"].ToString();
                }
                // .NET type name, e.g. "System.Int32".
                model.ColumnDataTypeByCSharp = row["DataType"].ToString();
                model.ColumnLength = ReadInt(row, "ColumnSize", 0);
                model.ColumnOrder = ReadInt(row, "ColumnOrdinal", i);

                string remark;
                if (remarks != null && remarks.TryGetValue(model.ColumnName, out remark))
                {
                    model.ColumnRemark = remark;
                }
                else
                {
                    model.ColumnRemark = model.ColumnName;
                }

                // The first identity flag column that is present and non-null decides IsIdentity.
                foreach (string flagColumn in IdentityColumns)
                {
                    if (colInfo.Columns.Contains(flagColumn) && row[flagColumn] != DBNull.Value)
                    {
                        model.IsIdentity = Convert.ToBoolean(row[flagColumn]);
                        break;
                    }
                }

                // A missing/NULL AllowDBNull is treated as nullable rather than throwing InvalidCastException.
                model.IsNullAble = ReadBool(row, "AllowDBNull", true);
                model.IsPK = ReadBool(row, "IsKey", false);
                model.IsUse = true;
                model.TableName = tableName;
                if (model.IsPK) hasPK = true;
                colList.Add(model);
            }

            // No key in the schema table (SQL Server 2008 cannot provide it there): ask the provider lookup.
            if (!hasPK)
            {
                List<string> pks = GetPrimarykeys(tableName);
                if (pks != null)
                {
                    foreach (string pkCol in pks)
                    {
                        ModelField col = colList.Find(s => s.ColumnName == pkCol);
                        // A key column may be missing from colList (e.g. hidden); skip it instead of throwing.
                        if (col == null) continue;
                        col.IsPK = true;
                        hasPK = true;
                    }
                }
            }

            if (!hasPK)
            {
                MarkGuessedPrimaryKey(colList);
            }
            return colList;
        }

        /// <summary>
        /// Marks a guessed primary key when none was found: the first identity column, otherwise
        /// the first non-nullable column, otherwise the first column. The guessed column gets the
        /// remark "主键?" ("primary key?"). Does nothing for an empty list.
        /// </summary>
        private static void MarkGuessedPrimaryKey(List<ModelField> colList)
        {
            ModelField guess = colList.Find(c => c.IsIdentity);
            if (guess == null) guess = colList.Find(c => !c.IsNullAble);
            if (guess == null && colList.Count > 0) guess = colList[0];
            if (guess == null) return;
            guess.IsPK = true;
            guess.Remark = "主键?";
        }

        /// <summary>
        /// Reads a boolean schema-table value; returns <paramref name="defaultValue"/> when the
        /// column does not exist or the value is null/DBNull.
        /// </summary>
        private static bool ReadBool(DataRow row, string column, bool defaultValue)
        {
            if (!row.Table.Columns.Contains(column)) return defaultValue;
            object value = row[column];
            return value == null || value is DBNull ? defaultValue : Convert.ToBoolean(value);
        }

        /// <summary>
        /// Reads an integer schema-table value; returns <paramref name="defaultValue"/> when the
        /// column does not exist or the value is null/DBNull.
        /// </summary>
        private static int ReadInt(DataRow row, string column, int defaultValue)
        {
            if (!row.Table.Columns.Contains(column)) return defaultValue;
            object value = row[column];
            return value == null || value is DBNull ? defaultValue : Convert.ToInt32(value);
        }

        /// <summary>
        /// Returns the stored procedures (and functions) of the current connection, sorted by name,
        /// from the "Procedures" schema collection.
        /// </summary>
        protected internal virtual List<ModelProc> GetProcInfo()
        {
            DataTable procInfo = GetSortedSchema("Procedures", "Specific_Name");
            List<ModelProc> list = new List<ModelProc>(procInfo.Rows.Count);
            foreach (DataRow row in procInfo.Rows)
            {
                ModelProc model = new ModelProc();
                model.ProcName = row["Specific_Name"].ToString();
                model.ProcRemark = string.Empty;
                model.ProcType = row["Routine_Type"].ToString();
                list.Add(model);
            }
            return list;
        }

        /// <summary>
        /// Returns the backing-field name generated for a column: the column name prefixed with "_".
        /// </summary>
        protected string GetClassField(string fieldName)
        {
            return "_" + fieldName;
        }

        /// <summary>
        /// Returns the property name generated for a column: currently the column name unchanged.
        /// </summary>
        protected string GetClassAttribute(string fieldName)
        {
            return fieldName;
        }

        /// <summary>
        /// Returns true when <paramref name="typeName"/> resolves to a value type.
        /// Type.GetType only resolves non-assembly-qualified names from the calling assembly and the
        /// core library and returns null otherwise, so unresolved names are treated as non-value types.
        /// </summary>
        private static bool IsValueTypeName(string typeName)
        {
            if (string.IsNullOrEmpty(typeName)) return false;
            Type type = Type.GetType(typeName);
            return type != null && type.IsValueType;
        }

        /// <summary>
        /// Generates the complete entity-class source for a table: one backing field and one
        /// change-tracking property per column, plus a region of "CN{Column}" name constants.
        /// </summary>
        /// <param name="tableName">Class name to generate.</param>
        /// <param name="appNameSpace">Namespace of the generated class.</param>
        /// <param name="fieldList">Column metadata; null is treated as an empty list.</param>
        protected internal string GetModelCode(string tableName, string appNameSpace, List<ModelField> fieldList)
        {
            // GetColumnInfo returns null when no column schema is available; emit an empty class instead of throwing.
            if (fieldList == null) fieldList = new List<ModelField>();

            StringPlus sp = new StringPlus();
            sp.AppendLine("/*CodeBuilder v2.0.1 by {0:yyyy-MM-dd HH:mm} */", DateTime.Now);
            sp.AppendLine("using System;");
            sp.AppendLine("using ZhCun.DbCore.Entitys;");
            sp.AppendLine();
            sp.AppendLine("namespace {0}", appNameSpace);
            sp.AppendLine("{");
            sp.AppendSpaceLine(1, "public partial class {0} : EntityBase", tableName);
            sp.AppendSpaceLine(1, "{");

            // Column-name constants are collected separately and appended after the properties.
            StringPlus fieldNameSP = new StringPlus();
            fieldNameSP.AppendSpaceLine(2, "#region 字段名的定义");
            foreach (var field in fieldList)
            {
                string fieldName = GetClassField(field.ColumnName);
                string attributeName = GetClassAttribute(field.ColumnName);
                string cSharpDataTypeStr = field.ColumnDataTypeByCSharp;
                fieldNameSP.AppendSpaceLine(2, "public const string CN{0} = \"{0}\";", attributeName);

                // Nullable value-type columns become Nullable<T>.
                if (field.IsNullAble && IsValueTypeName(field.ColumnDataTypeByCSharp))
                {
                    cSharpDataTypeStr += "?";
                }
                sp.AppendSpaceLine(2, "private {0} {1};", cSharpDataTypeStr, fieldName);
                sp.AppendSpaceLine(2, "/// <summary>");
                // Every remark line needs its own "///" prefix, otherwise the generated file does not compile.
                foreach (string remarkLine in (field.ColumnRemark ?? string.Empty).Split(NewLines, StringSplitOptions.None))
                {
                    sp.AppendSpaceLine(2, "/// {0}", remarkLine);
                }
                sp.AppendSpaceLine(2, "/// </summary>");
                // e.g. [Entity(CNId, 1, IsPrimaryKey = true, IsNotNull = true)]
                sp.AppendSpace(2, "[Entity(CN{0}, {1}", attributeName, field.ColumnIndex);
                if (!field.IsNullAble)
                {
                    sp.Append(", true");
                }
                if (field.IsPK)
                {
                    sp.Append(", IsPrimaryKey = true");
                }
                if (field.IsIdentity)
                {
                    sp.Append(", IsIdentity = true");
                }
                sp.AppendLine(")]");
                sp.AppendSpaceLine(2, "public {0} {1}", cSharpDataTypeStr, attributeName);
                sp.AppendSpaceLine(2, "{");
                sp.AppendSpaceLine(2 + 1, "get {{ return {0}; }}", fieldName);
                sp.AppendSpaceLine(2 + 1, "set");
                sp.AppendSpaceLine(2 + 1, "{");
                sp.AppendSpaceLine(2 + 1 + 1, "if (!OnPropertyChanged(CN{0}, {1}, value)) return;", attributeName, fieldName);
                sp.AppendSpaceLine(2 + 1 + 1, "{0} = value;", fieldName);
                sp.AppendSpaceLine(2 + 1 + 1, "SetFieldChanged(CN{0}) ;", attributeName);
                sp.AppendSpaceLine(2 + 1, "}");
                sp.AppendSpaceLine(2, "}");
                sp.AppendLine();
            }
            fieldNameSP.AppendSpaceLine(2, "#endregion");
            sp.AppendLine(fieldNameSP.ToString());
            sp.AppendSpaceLine(1, "}");
            sp.AppendLine("}");
            return sp.ToString();
        }

        /// <summary>
        /// Reads the columns of <paramref name="tableName"/> and generates its entity-class source.
        /// </summary>
        public string GetModelCode(string tableName, string appNamespace)
        {
            List<ModelField> fieldList = GetColumnInfo(tableName);
            return GetModelCode(tableName, appNamespace, fieldList);
        }

        /// <summary>
        /// Returns the names of all tables and views of the current connection.
        /// </summary>
        public List<string> GetTableNameList()
        {
            List<string> tableNameList = new List<string>();
            List<ModelTable> tableInfoList = GetTableInfo();
            if (tableInfoList == null) return tableNameList;
            foreach (ModelTable table in tableInfoList)
            {
                tableNameList.Add(table.TableName);
            }
            return tableNameList;
        }

        /// <summary>
        /// Returns the (visible) column names of a table; empty when no column schema is available.
        /// </summary>
        public List<string> GetColumnNameList(string tableName)
        {
            List<string> columnNameList = new List<string>();
            List<ModelField> columnModelList = GetColumnInfo(tableName);
            if (columnModelList == null) return columnNameList;
            foreach (ModelField column in columnModelList)
            {
                columnNameList.Add(column.ColumnName);
            }
            return columnNameList;
        }

        /// <summary>
        /// Returns the names of all stored procedures (and functions) of the current connection.
        /// </summary>
        public virtual List<string> GetProceduresList()
        {
            List<string> rList = new List<string>();
            List<ModelProc> procList = GetProcInfo();
            if (procList == null) return rList;
            foreach (ModelProc proc in procList)
            {
                rList.Add(proc.ProcName);
            }
            return rList;
        }

        /// <summary>
        /// Returns the named parameters of a stored procedure from the "ProcedureParameters"
        /// schema collection. Rows with an empty Parameter_Name are skipped.
        /// </summary>
        /// <param name="procName">Procedure name, matched case-insensitively (ordinal). Null yields an empty list.</param>
        protected internal virtual List<ModelProcParam> GetProcParamInfo(string procName)
        {
            List<ModelProcParam> list = new List<ModelProcParam>();
            if (procName == null) return list;

            DataTable paramInfo = GetSortedSchema("ProcedureParameters", "Specific_Name");
            foreach (DataRow dr in paramInfo.Rows)
            {
                // Ordinal rather than culture-aware comparison: procedure names are identifiers.
                if (!procName.Equals(dr["Specific_Name"].ToString(), StringComparison.OrdinalIgnoreCase)) continue;

                string paramName = dr["Parameter_Name"].ToString();
                // Unnamed rows (presumably a function's return value) are skipped.
                if (string.IsNullOrEmpty(paramName)) continue;

                ModelProcParam model = new ModelProcParam();
                // Strip the leading '@' marker; names reported without it are kept intact
                // rather than losing their first character.
                model.ParamName = paramName[0] == '@' ? paramName.Substring(1) : paramName;
                model.ParamDataType = dr["Data_Type"].ToString();
                model.ParamDataType4CSharp = GetCSharpDataTypeString(model.ParamDataType);

                // DBNull.ToString() is "", so a NULL length leaves ParamLength untouched.
                object maxLength = dr["CHARACTER_MAXIMUM_LENGTH"];
                if (maxLength != null && maxLength.ToString().Length > 0)
                {
                    model.ParamLength = Convert.ToInt32(maxLength);
                }
                model.ParamOutType = dr["PARAMETER_MODE"].ToString();
                list.Add(model);
            }
            return list;
        }

        /// <summary>
        /// Reads the parameters of <paramref name="procName"/> and generates its parameter-class source.
        /// </summary>
        public string GetProcModelCode(string procName, string appNamespace)
        {
            var procParamList = GetProcParamInfo(procName);
            return GetProcedureModelCode(procName, appNamespace, procParamList);
        }

        /// <summary>
        /// Generates a ProcEntityBase-derived class with one auto-property per procedure parameter;
        /// parameters whose mode contains "OUT" are marked as output parameters.
        /// </summary>
        /// <param name="proceName">Class name to generate.</param>
        /// <param name="appNamespace">Namespace of the generated class.</param>
        /// <param name="procParamList">Parameter metadata; null generates a class without properties.</param>
        protected internal string GetProcedureModelCode(string proceName, string appNamespace, List<ModelProcParam> procParamList)
        {
            StringPlus sp = new StringPlus();
            sp.AppendLine("using System.Data;");
            sp.AppendLine("using ZhCun.DbCore.Entitys;");
            sp.AppendLine();
            sp.AppendLine("namespace {0}", appNamespace);
            sp.AppendLine("{");
            sp.AppendSpaceLine(1, "public class {0} : ProcEntityBase", proceName);
            sp.AppendSpaceLine(1, "{");
            if (procParamList != null)
            {
                foreach (ModelProcParam param in procParamList)
                {
                    string attributeName = GetClassAttribute(param.ParamName);
                    // Contains("OUT") matches both the OUT and INOUT modes; both are emitted as Output.
                    if (param.ParamOutType != null && param.ParamOutType.Contains("OUT"))
                    {
                        sp.AppendSpaceLine(2, "[ProcParam(ParamDirection= ParameterDirection.Output)]");
                    }
                    sp.AppendSpace(2, "public {0} {1}", param.ParamDataType4CSharp, attributeName);
                    sp.AppendLine(" { set; get; }");
                    sp.AppendLine();
                }
            }
            sp.AppendSpaceLine(1, "}");
            sp.AppendLine("}");
            return sp.ToString();
        }
    }
}